LlamaIndexのコードに潜って、マルチモーダル対応の実装を確認してきたのでシェアする
こんちには。
データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。
LlamaIndexはマルチモーダルにも対応しているのですね!以下のブログを見るまで私も知りませんでした!
BLIPなど画像を理解してベクトル化するモデルは、GPUを使わないと推論を動かすことは難しい印象が私にはありました。
しかしこの記事ではCPU実行できちんと動作しているようです。いったいどのように実現しているのでしょうか?
少し掘り下げてみてみましょう。(結論だけを見たい方はまとめをご覧ください)
画像はImageParserで処理している
pngを処理していそうな箇所を検索してみると、拡張子毎にParserが定義されている箇所を見つけました。
これを見たところ、pngやjpgはImageParserで処理しているようです。
ImageParserはHugging Faceのモデルを使用
ImageParserのコードを見たところ、いくつか依存関係のあるライブラリがあります。
すくなくとも以下のコードあたりは依存関係がありそうです。
import torch from transformers import DonutProcessor, VisionEncoderDecoderModel import sentencepiece from PIL import Image
さらに、DonutProcessor
とVisionEncoderDecoderModel
については、以下のようになっています。
processor = DonutProcessor.from_pretrained( "naver-clova-ix/donut-base-finetuned-cord-v2" ) model = VisionEncoderDecoderModel.from_pretrained( "naver-clova-ix/donut-base-finetuned-cord-v2" )
このことから、Hugging Faceに公開されている以下を使っていると考えて良さそうです。
donut-base-finetuned-cord-v2とは
上記のリンクの説明にあるとおり、CORDでfine-tuningされたDonut Modelと記載されており、「Donut」はDocument understanding transformerの頭文字です。
このDonut ModelはOCRフリーであることを目指したVDU(Visual Document Understanding)モデルとなっていてます。
要するに前処理にOCRを使い、それを元に文書理解をするのではなく、文字抽出と文書理解をEnd-to-Endで解く様なモデルを学習していると考えて良さそうです。
おもに画像用のエンコーダ(Swin Transformerを使用)とテキスト用のデコーダ(BARTを使用)から構成されています。
詳細は以下の元論文も参照ください。
せっかくなのでDonutを動かしてみる
Hugging Faceの以下にサンプルがいくつかあります。
解くことが可能なタスクとして以下が上げられています。
- Document Image Classification
- Document Parsing
- Document Visual Question Answering (DocVQA)
各タスクはスペシャルトークンをデコーダの先頭に入れることで、タスクを認識させているようです。
ImageParserで実行されているのはDocument Parsingに該当する、<s>
というトークンが与えられています。
ですので、LlamaIndexでも使用されているDocument Parsingを試してみます。
動かしてみた
使用環境
Google Colaboratoryを使います。ハードウェアアクセラレータは無し、ラインタイム仕様も標準です。
以下をインストールします。
!pip install transformers !pip install datasets !pip install sentencepiece
その後確認した、主なバージョン情報は以下です。
!python --version
Python 3.9.16
!pip freeze | grep \ -e "transformers" -e "torch @" -e "sentencepiece" -e "Pillow" \ -e "^datasets==" -e "tokenizers"
datasets==2.10.1 Pillow==8.4.0 sentencepiece==0.1.97 tokenizers==0.13.2 torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl transformers==4.27.2
準備
インポートは以下とします。
from transformers import DonutProcessor, VisionEncoderDecoderModel from datasets import load_dataset import torch import re from PIL import Image import numpy as np
モデル取得
モデルはHugging Faceから、特に認証ナシに取得できます。
processor = DonutProcessor\ .from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") model = VisionEncoderDecoderModel\ .from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device)
サンプルデータ取得
サンプルデータも取得することができます。
dataset = load_dataset("hf-internal-testing/example-documents", split="test") image = dataset[2]["image"]
処理する画像は以下のようなレシートになります。
image
モデルによるパース(推論処理)
タスクを識別するスペシャルトークンを、数値IDに変換しておきます。
task_prompt = "<s_cord-v2>" decoder_input_ids = processor.tokenizer( task_prompt , add_special_tokens=False , return_tensors="pt").input_ids decoder_input_ids
tensor([[57579]])
画像の前処理を行うprocessorで処理します。処理の中身は割愛しますが、出力としては決まったサイズのtorch.tensorが得られるようです。
pixel_values = processor(image, return_tensors="pt").pixel_values pixel_values.shape
torch.Size([1, 3, 1280, 960])
モデルで推論結果を得ます。
outputs = model.generate( pixel_values.to(device), decoder_input_ids=decoder_input_ids.to(device), max_length=model.decoder.config.max_position_embeddings, early_stopping=True, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True, ) outputs
GreedySearchEncoderDecoderOutput(sequences=tensor([[57579, 57526, 57528, 40920, 57527, 57560, 40474, 35815, 48845, 486, 57559, 57530, 34751, 57529, 57532, 38873, 35815, 486, 42438, 57531, 57522, 57528, 46040, 24277, 57527, 57532, 38873, 35815, 486, 12569, 57531, 57525, 57544, 57546, 38100, 35815, 38921, 57545, 57550, 38100, 35815, 28268, 43112, 57549, 57543, 2]]), scores=None, encoder_attentions=None, encoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, decoder_hidden_states=None)
処理時間を計測したところ、画像1枚で20秒程度となっており、CPUでも処理できそうな処理量となっていそうです。
得られた出力は、まだ数値ID列のままであるため、デコードして読むことが可能なデータに戻します。
sequence = processor.batch_decode(outputs.sequences)[0] sequence = sequence.replace(processor.tokenizer.eos_token, "")\ .replace(processor.tokenizer.pad_token, "") sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token print(processor.token2json(sequence))
{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
記述されているレシートの内容がJSON形式(辞書形式)で取得できていることが分かりました。
以降のために、推論に必要な処理を関数化しておきます。
def parse(image: Image) -> str: pixel_values = processor(image, return_tensors="pt").pixel_values outputs = model.generate( pixel_values.to(device), decoder_input_ids=decoder_input_ids.to(device), max_length=model.decoder.config.max_position_embeddings, early_stopping=True, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True, ) sequence = processor.batch_decode(outputs.sequences)[0] sequence = sequence.replace(processor.tokenizer.eos_token, "")\ .replace(processor.tokenizer.pad_token, "") sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token return processor.token2json(sequence)
日本語対応は?
日本語対応ができるのか、以下のような画像データをjinko_ja.png
として作成して試してみます。
Image.open
がlazy operationであるため、今回は一旦複製して実体化させます。
with open("jinko_ja.png", "rb") as f: image = Image.open(f) image = image.copy() np.array(image).shape
(423, 539, 4)
フォーマットが4チャンネル(RGBA)となっているようだったので以下で3チャンネル(RGB)にします。
r = image.getchannel('R') g = image.getchannel('G') b = image.getchannel('B') image = Image.merge('RGB', (r, g, b)) np.array(image).shape
これらの処理もまとめて関数化しておきます。
def load(image_file: str) -> Image: with open(image_file, "rb") as f: image = Image.open(f) r = image.getchannel('R') g = image.getchannel('G') b = image.getchannel('B') image = Image.merge('RGB', (r, g, b)) return image
では一連の処理をして推論結果を得てみましょう。
image = load("jinko_ja.png") result = parse(image) print(result)
{'menu': {'nm': '東京', 'price': '12,758'}, 'sub_total': {'subtotal_price': '8,880', 'tax_price': '8,812'}, 'total': {'total_price': '7,360', 'creditcardprice': '7,090'}}
漢字自体は読み込めそうな動きをしていそうですが、一部のみとなっており、存在しない値も抽出していそうです。
以下のように英語版を使用するとどうでしょうか?ファイルをjinko_en.png
として保存して試してみましょう。
image = load("jinko_en.png") result = parse(image) print(result)
{'menu': [{'nm': 'population', 'unitprice': '12,758', 'cnt': '了', 'price': '8,880'}, {'nm': 'Kanagawa', 'price': '8,812'}], 'total': {'total_price': '7,360', 'cashprice': '7,090'}}
日本語と同様に、書いていないキーの情報が検出されています。推測ですが、CORDのデータセットにfine tuningされているためと考えられます。
ですので、現状のままで任意の用途に使用することは難しく、ある程度用途を絞ったデータセットで学習して使う必要がありそうですね。
まとめ
いかがでしたでしょうか。
LlamaIndexのマルチモーダル対応の中身を掘っていったら、気付いたらDonutというVDU(視覚文書理解)モデルを動かしてしまいました。
その結果、LlamdaIndexは現在のところ画像を直接的に埋め込みベクトルには変換せず、単純に記載されている文字データとしてパースしているようです。 ですので、猫の画像をいれても、猫として認識させることは現状難しいということになります。(猫という文字ならOK)
またある特定のデータセットでfine-tuningされている点もポイントとして挙げられます。
そのままでは一般的な用途に使うことは難しいため、きちんと使用する場合は、ある程度用途を絞ったデータセットでfine-tuningするなどが必要ということが分かったことも収穫でした。
GPT-4などはBLIPなどと同様に、画像自体を埋め込みベクトルに変換可能な、マルチモーダルな入力を受け付ける大規模言語モデル(LLM)と認識しています。
これらのAPIがOpenAIなどによって公開されれば、LlamdaIndexにもこれらが取り込まれ、いずれは画像自体を理解して処理してくれるのでは期待しています。
本記事が、今後LlamaIndexをお使いになられる方の参考になれば幸いです。